백준 17472번 - 다리 만들기2

문제링크

풀이과정

  1. 각 섬을 노드로 표현
  2. 섬끼리의 최소 거리를 노드간의 가중치로 표현
  3. 최소 신장 트리 작성
    즉, 연결성분 -> 그래프표현 -> 최소신장트리 과정으로 문제를 풀었다.

각 섬을 노드로 표현

좌표 평면에 1로 이루어진 섬들을 그룹핑지어야했다.
연결성분 검사(connected component)
그래프가 아닌 이차원배열의 자료구조를 그룹핑 짓는게 어려웠다.

for i in range(n):
    for j in range(m):
        if arr[i][j] == 1:
            group += 1
            stack.append((i, j, group))
        while stack:
            x, y, group = stack.pop()
            arr[x][y] = group
            for dx, dy in d:
                newx, newy = x + dx, y + dy
                if inRange(newx, newy) and arr[newx][newy] == 1:
                    stack.append((newx, newy, group))

  1. 2차원 배열 순회 하면서 group이 정해지지 않은 칸을 stack에 추가
  2. 동서남북 인접 칸들 중 섬인 지역을 stack에 추가

생각해보니
연결성분 -> 그래프 대신에
그래프 -> 연결성분 방법도 가능할 것 같다.

다시 생각해보니 섬들을 연결성분으로 묶기 전에 그래프로 표현하면 3차원으로 표현해야해서 더 복잡할 것 같다. 또한 섬들의 크기 등이 필요한게 아니라서 연결성분으로 먼저 묶는데 편한 것 같다.

다리를 간선으로 표현

가로 세로로 순회하면서 0으로 바뀌는 부분을 확인하여 거리를 간선을 바꾸어 줬다.

def addEdge(a, b, d):
    if d < 2:
        return
    if graph[a].get(b) == None:
        graph[a][b] = d
    else:
        graph[a][b] = min(graph[a][b], d)
    if graph[b].get(a) == None:
        graph[b][a] = d
    else:
        graph[b][a] = min(graph[b][a], d)


for i in range(n):
    currentGroup = 0
    distance = 0
    for j in range(m):
        if arr[i][j] == 0:
            distance += 1
        elif arr[i][j] == currentGroup:
            distance = 0
        elif arr[i][j] != currentGroup:
            if currentGroup != 0:
                addEdge(currentGroup, arr[i][j], distance)
            currentGroup = arr[i][j]
            distance = 0

for j in range(m):
    currentGroup = 0
    distance = 0
    for i in range(n):
        if arr[i][j] == 0:
            distance += 1
        elif arr[i][j] == currentGroup:
            distance = 0
        elif arr[i][j] != currentGroup:
            if currentGroup != 0:
                addEdge(currentGroup, arr[i][j], distance)
            currentGroup = arr[i][j]
            distance = 0

최소신장 트리

그래프로 저장해서 크루스칼보다 prim이 알고리즘이 좀 더 편할 것 같아서 프림으로 구현했다.

def prim(start):
    mst = []
    heap = []
    total = 0
    visited = {i: False for i in range(2, group + 1)}
    heap.append((0, start))
    while heap:
        weight, node = heapq.heappop(heap)
        if visited[node] == True:
            continue
        visited[node] = True
        mst.append(node)
        total += weight
        for next, w in graph[node].items():
            heapq.heappush(heap, (w, next))
    if len(mst) < group - 1:
        return -1
    return total

전체 코드

import heapq

n, m = map(int, input().split())

arr = [list(map(int, input().split())) for _ in range(n)]

group = 1

d = [(0, -1), (0, 1), (-1, 0), (1, 0)]


def inRange(x, y):
    return x >= 0 and x < n and y >= 0 and y < m


stack = []

for i in range(n):
    for j in range(m):
        if arr[i][j] == 1:
            group += 1
            stack.append((i, j, group))
        while stack:
            x, y, group = stack.pop()
            arr[x][y] = group
            for dx, dy in d:
                newx, newy = x + dx, y + dy
                if inRange(newx, newy) and arr[newx][newy] == 1:
                    stack.append((newx, newy, group))


graph = {i: {} for i in range(2, group + 1)}


def addEdge(a, b, d):
    if d < 2:
        return
    if graph[a].get(b) == None:
        graph[a][b] = d
    else:
        graph[a][b] = min(graph[a][b], d)
    if graph[b].get(a) == None:
        graph[b][a] = d
    else:
        graph[b][a] = min(graph[b][a], d)


for i in range(n):
    currentGroup = 0
    distance = 0
    for j in range(m):
        if arr[i][j] == 0:
            distance += 1
        elif arr[i][j] == currentGroup:
            distance = 0
        elif arr[i][j] != currentGroup:
            if currentGroup != 0:
                addEdge(currentGroup, arr[i][j], distance)
            currentGroup = arr[i][j]
            distance = 0

for j in range(m):
    currentGroup = 0
    distance = 0
    for i in range(n):
        if arr[i][j] == 0:
            distance += 1
        elif arr[i][j] == currentGroup:
            distance = 0
        elif arr[i][j] != currentGroup:
            if currentGroup != 0:
                addEdge(currentGroup, arr[i][j], distance)
            currentGroup = arr[i][j]
            distance = 0


def edges(x, y):
    for i in range(x):
        currentGroup = 0
        distance = 0
        for j in range(y):
            if arr[i][j] == 0:
                distance += 1
            elif arr[i][j] == currentGroup:
                distance = 0
            elif arr[i][j] != currentGroup:
                if currentGroup != 0:
                    addEdge(currentGroup, arr[i][j], distance)
                currentGroup = arr[i][j]
                distance = 0


edges(n, m)
edges(m, n)


def prim(start):
    mst = []
    heap = []
    total = 0
    visited = {i: False for i in range(2, group + 1)}
    heap.append((0, start))
    while heap:
        weight, node = heapq.heappop(heap)
        if visited[node] == True:
            continue
        visited[node] = True
        mst.append(node)
        total += weight
        for next, w in graph[node].items():
            heapq.heappush(heap, (w, next))
    if len(mst) < group - 1:
        return -1
    return total


print(prim(2))